from datasets import Dataset


def _default_extraction_row(ex: dict) -> dict:
    """Return a copy of ``ex`` filled with the "nothing extracted" placeholders."""
    row = dict(ex)  # copy original columns
    row["extracted_question"] = "NO QUESTION DETECTED"
    row["extracted_answer_choices"] = ["NO CHOICES AVAILABLE"]
    row["extracted_solution"] = "NO ANSWER DETECTED"
    return row


def _split_answer_choices(answer_choices_raw: str) -> list:
    """Turn raw ANSWER CHOICES text into a non-empty list of strings.

    Args:
        answer_choices_raw: The stripped text between "ANSWER CHOICES: " and
            "SOLUTION: " (or the literal "free response" when absent).

    Returns:
        The individual choices, or ``["NO CHOICES AVAILABLE"]`` when there are none.
    """
    if "|||" in answer_choices_raw:
        # Split on the bare delimiter and strip each piece, so both
        # "a ||| b" and "a|||b" are split (splitting on " ||| " missed the latter
        # even though detection only looked for "|||").
        choices = [
            choice.strip()
            for choice in answer_choices_raw.split("|||")
            if choice.strip()
        ]
        return choices or ["NO CHOICES AVAILABLE"]
    if not answer_choices_raw or answer_choices_raw.lower() == "free response":
        return ["NO CHOICES AVAILABLE"]
    return [answer_choices_raw]


def _parse_question_section(section: str) -> tuple:
    """Parse one question section into ``(question, answer_choices, solution)``.

    Args:
        section: Stripped text of one section produced by splitting on
            "\\nQUESTION: ". The first section of an entry may still begin
            with "QUESTION: ".

    Returns:
        A ``(str, list[str], str)`` tuple; every element is non-empty.
    """
    # parts[0] is the question; parts[1], if present, holds choices + solution.
    parts = section.split("\nANSWER CHOICES: ")
    question_text = parts[0]

    if len(parts) > 1:
        answer_solution_parts = parts[1].split("\nSOLUTION: ")
        answer_choices_raw = answer_solution_parts[0].strip()
        solution = answer_solution_parts[1] if len(answer_solution_parts) > 1 else ""
    else:
        # No explicit "ANSWER CHOICES:" block: treat the question as free response.
        answer_choices_raw = "free response"
        solution = ""

    # Remove marker text first and only then fall back to the placeholder, so a
    # section consisting only of "QUESTION:" cannot produce an empty string.
    question = question_text.replace("QUESTION:", "").strip() or "NO QUESTION DETECTED"
    solution = solution.replace("SOLUTION:", "").strip() or "NO ANSWER DETECTED"
    return question, _split_answer_choices(answer_choices_raw), solution


def process_extracted_questions(dataset: Dataset) -> Dataset:
    """
    Processes each example from the input dataset WITHOUT using .map().

    For each input row:
        - We look for multiple "QUESTION: " blocks in the 'question_choices_solutions' column.
        - Each found question/answers/solution is expanded into a separate output row.
        - We preserve all original columns and add new columns:
            * 'extracted_question' (string, never empty)
            * 'extracted_answer_choices' (list[str] with at least one element)
            * 'extracted_solution' (string, never empty)

    Rows whose 'question_choices_solutions' is missing, not a string, or blank
    produce a single row of placeholder values ("NO QUESTION DETECTED",
    ["NO CHOICES AVAILABLE"], "NO ANSWER DETECTED").

    NOTE(review): non-blank text that contains no "QUESTION: " marker at all is
    parsed as a single question (this matches the previous behavior).

    Returns:
        A new Dataset with possibly more rows than the original,
        containing the added columns.
    """
    new_rows = []

    for ex in dataset:
        # Safely get the content from "question_choices_solutions".
        entry = ex.get("question_choices_solutions", "")
        entry = entry.strip() if isinstance(entry, str) else ""

        rows_for_example = []
        if entry:
            for section in entry.split("\nQUESTION: "):
                section = section.strip()
                if not section:
                    continue  # skip empty splits
                question, choices, solution = _parse_question_section(section)
                row = dict(ex)  # preserve the original columns
                row["extracted_question"] = question
                row["extracted_answer_choices"] = choices
                row["extracted_solution"] = solution
                rows_for_example.append(row)

        # Every input example yields at least one output row.
        new_rows.extend(rows_for_example or [_default_extraction_row(ex)])

    # Convert the list of dictionaries into a new Hugging Face Dataset.
    return Dataset.from_list(new_rows)


def process_extracted_questions_unverified(dataset: Dataset) -> Dataset:
    """
    Expand each row's "QUESTION: " blocks into separate rows (unverified variant).

    This function previously contained a line-for-line copy of
    ``process_extracted_questions``. It now delegates to that function, so the
    two parsers cannot silently drift apart. If the unverified data ever needs
    different parsing rules, implement them here instead of delegating.

    For each input row, all original columns are preserved and these are added:
        * 'extracted_question' (string)
        * 'extracted_answer_choices' (list[str] with at least one element)
        * 'extracted_solution' (string)

    Returns:
        A new Dataset with possibly more rows than the original,
        containing the added columns.
    """
    return process_extracted_questions(dataset)


def filter_valid_qas(dataset: Dataset) -> Dataset:
    """
    Keep only rows that have a usable question and a concrete solution.

    A row is dropped when any of the following holds:
    1) extracted_solution is a "no answer" placeholder:
       'NO ANSWER DETECTED' or '"NO ANSWER DETECTED"'
    2) extracted_solution is a free-response placeholder:
       'free response answer', '"free response"' or 'free response'
    3) extracted_question is a "no question" placeholder:
       'NO QUESTION DETECTED' (the value written by process_extracted_questions)
       or 'NO ANSWER DETECTED' (the value this filter originally checked)

    Args:
        dataset (Dataset): A Hugging Face Dataset with columns 'extracted_question' and 'extracted_solution'.

    Returns:
        Dataset: Filtered dataset containing only the rows that pass every check.
    """
    # Tuples (not sets) so membership uses == and never requires hashable values.
    invalid_solutions = (
        "NO ANSWER DETECTED",
        '"NO ANSWER DETECTED"',
        "free response answer",
        '"free response"',
        "free response",
    )
    invalid_questions = ("NO QUESTION DETECTED", "NO ANSWER DETECTED")

    def is_valid(example):
        return (
            example["extracted_solution"] not in invalid_solutions
            and example["extracted_question"] not in invalid_questions
        )

    return dataset.filter(is_valid)


def filter_unverified_valid_qas(dataset: Dataset) -> Dataset:
    """
    Keep only rows whose extracted_question is not a placeholder.

    Only the question is checked; the solution is ignored, since unverified rows
    may legitimately lack one. A row is dropped when extracted_question equals
    'NO QUESTION DETECTED' (the value written by the process_extracted_questions*
    functions) or 'NO ANSWER DETECTED' (the value this filter originally checked;
    kept for backward compatibility).

    Args:
        dataset (Dataset): A Hugging Face Dataset with an 'extracted_question' column.

    Returns:
        Dataset: Filtered dataset containing only rows with a real question.
    """
    invalid_questions = ("NO QUESTION DETECTED", "NO ANSWER DETECTED")

    def is_valid(example):
        return example["extracted_question"] not in invalid_questions

    return dataset.filter(is_valid)


def filter_bad_pdfs(dataset: Dataset) -> Dataset:
    """
    Drop rows whose 'page_count' is not a positive number.

    Rows with a 'page_count' of 0, a negative value, or None are removed.

    Args:
        dataset (Dataset): A Hugging Face Dataset with a numeric 'page_count' column.

    Returns:
        Dataset: Filtered dataset containing only rows with page_count > 0.
    """

    def has_pages(example):
        page_count = example["page_count"]
        # A missing (None) count would otherwise raise TypeError on the comparison.
        return page_count is not None and page_count > 0

    return dataset.filter(has_pages)


def filter_questions_with_the_answer_in_them(dataset: Dataset) -> Dataset:
    """
    Drop rows whose question text already contains the answer.

    Both 'extracted_solution' and 'improved_question_solution' are normalized
    the same way: single and double quotes are removed and the text is
    lower-cased. A row is kept only when the normalized solution is NOT a
    substring of the normalized question text. Normalizing both sides means a
    solution such as "it's" still matches "it's" inside the question.

    An empty normalized solution is a substring of every string, so such rows
    are dropped.

    NOTE(review): plain substring matching is aggressive for very short
    solutions (e.g. "2" or "a"); consider word-boundary matching if too many
    rows are removed.

    Args:
        dataset (Dataset): A Hugging Face Dataset with string columns
            'extracted_solution' and 'improved_question_solution'.

    Returns:
        Dataset: Filtered dataset without questions that reveal their answer.
    """
    # Translation table that deletes both quote characters in a single pass.
    strip_quotes = str.maketrans("", "", "\"'")

    def normalize(text):
        return text.translate(strip_quotes).lower()

    def is_valid(example):
        solution = normalize(example["extracted_solution"])
        question = normalize(example["improved_question_solution"])
        return solution not in question

    return dataset.filter(is_valid)
